# Copyright 2021-2023 @ Shenzhen Bay Laboratory &
#                       Peking University &
#                       Huawei Technologies Co., Ltd
#
# This code is a part of MindSPONGE:
# MindSpore Simulation Package tOwards Next Generation molecular modelling.
#
# MindSPONGE is open-source software based on the AI-framework:
# PyTorch (https://pytorch.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base energy cell"""

from typing import Union
import torch
from torch import Tensor
from torch import nn

from ...system.molecule import Molecule
from ...function import get_tensor
from ...function.units import Units, Length, GLOBAL_UNITS

_ENERGY_BY_KEY = {}


def _energy_register(*aliases):
    """Build a class decorator that records an energy class in ``_ENERGY_BY_KEY``.

    The decorated class is stored under its lower-cased class name first and
    then under every alias given here (aliases are used verbatim, not
    lower-cased). A key that is already present is never overwritten, so the
    first class registered under a given key wins.

    Args:
        *aliases: Extra keys under which the decorated class is registered.

    Returns:
        A decorator that registers its argument and returns it unchanged.
    """
    def register(klass):
        candidate_keys = (klass.__name__.lower(),) + aliases
        for key in candidate_keys:
            # setdefault only inserts when the key is absent: first one wins.
            _ENERGY_BY_KEY.setdefault(key, klass)
        return klass

    return register


class EnergyCell(nn.Module):
    r"""
    Base class for energy terms.

    `EnergyCell` is usually used as a base class for individual energy terms in a classical force field.
    As the force field parameters usually have units, the units of the `EnergyCell` as an energy term
    should be the same as the units of the force field parameters, and not equal to the global units.

    Args:
        name (str):         Name of energy. Default: 'energy'
        length_unit (str):  Length unit. If None is given, it will be assigned with the global length unit.
                            Default: 'nm'
        energy_unit (str):  Energy unit. If None is given, it will be assigned with the global energy unit.
                            Default: 'kj/mol'
        use_pbc (bool):     Whether to use periodic boundary condition. Default: None
        **kwargs:           Extra keyword arguments. They are stored in `self._kwargs` and not
                            interpreted by this base class.

    Returns:
        Tensor of energy, Tensor of shape `(B, 1)`. Data type is float.

    Supported Platforms:
        ``CPU`` ``GPU``

    Symbols:
        B:  Batchsize, i.e. number of walkers in simulation
    """
    def __init__(self,
                 name: str = 'energy',
                 length_unit: str = 'nm',
                 energy_unit: str = 'kj/mol',
                 use_pbc: bool = None,
                 **kwargs
                 ):
        super().__init__()
        # Keep unrecognised keyword arguments around (uninterpreted) for subclasses.
        self._kwargs = kwargs

        self._name = name

        self._use_pbc = use_pbc

        # Only an explicit None falls back to the global units.
        if length_unit is None:
            length_unit = GLOBAL_UNITS.length_unit
        if energy_unit is None:
            energy_unit = GLOBAL_UNITS.energy_unit
        self.units = Units(length_unit, energy_unit)

        # Factor from `self.units.convert_length_from(...)` for the unit of the input
        # coordinates; [1.] when no input unit has been set (see `set_input_unit`).
        # NOTE(review): this and `cutoff` are plain tensor attributes, not registered
        # buffers, so `Module.to(device)` does not move them — confirm this is intended.
        self.input_unit_scale = get_tensor([1.], dtype=torch.float32)
        # Cutoff distance (presumably in `self.length_unit`), or None for no cutoff.
        self.cutoff = None

    @property
    def name(self) -> str:
        """
        Name of energy.

        Returns:
            str, name of energy.
        """
        return self._name

    @property
    def use_pbc(self) -> bool:
        """
        Whether to use periodic boundary condition.

        Returns:
            bool, the flag used to judge whether to use periodic boundary condition.
            May be None when it has not been specified.
        """
        return self._use_pbc

    @property
    def length_unit(self) -> str:
        """
        Length unit.

        Returns:
            str, length unit.
        """
        return self.units.length_unit

    @property
    def energy_unit(self) -> str:
        """
        Energy unit.

        Returns:
            str, energy unit.
        """
        return self.units.energy_unit

    @staticmethod
    def check_system(system: Molecule) -> bool:
        """
        Check if the system needs to calculate this energy term.

        The base implementation accepts every system; subclasses may override it.

        Args:
            system (Molecule): The simulation system.

        Returns:
            bool, always True in the base class.
        """
        #pylint:disable=unused-argument
        return True

    def set_units(self, length_unit: str = None, energy_unit: str = None, units: Units = None):
        r"""
        Set the length and energy units.

        If `units` is given it takes precedence, and `length_unit` / `energy_unit` are ignored.
        Otherwise any of `length_unit` / `energy_unit` left as None falls back to the
        corresponding global unit.

        Args:
            length_unit (str):  Length unit. Default: None
            energy_unit (str):  Energy unit. Default: None
            units (Units):      Units object to copy the units from. Default: None

        Returns:
            EnergyCell, this instance (to allow chaining).
        """
        if units is None:
            if length_unit is None:
                length_unit = GLOBAL_UNITS.length_unit
            if energy_unit is None:
                energy_unit = GLOBAL_UNITS.energy_unit
        else:
            # An explicit `units` object overrides the individual unit arguments.
            length_unit = None
            energy_unit = None

        if self.units is None:
            self.units = Units(length_unit=length_unit, energy_unit=energy_unit, units=units)
        else:
            self.units.set_units(length_unit=length_unit, energy_unit=energy_unit, units=units)

        return self

    def set_input_unit(self, length_unit: Union[str, Units, Length]):
        """
        Set the length unit for the input coordinates.

        Args:
            length_unit (Union[str, Units, Length]): The length unit for the input coordinates.
                None resets the scale factor to 1. The accepted runtime types are
                `str`, `Units` and `float`; any other type raises TypeError.

        Returns:
            EnergyCell, this instance (to allow chaining).

        Raises:
            TypeError: If `length_unit` is of an unsupported type.
        """
        if length_unit is None:
            self.input_unit_scale = get_tensor([1.], dtype=torch.float32)
        elif isinstance(length_unit, (str, Units, float)):
            self.input_unit_scale = get_tensor(
                self.units.convert_length_from(length_unit), dtype=torch.float32)
        else:
            raise TypeError(f'Unsupported type of `length_unit`: {type(length_unit)}')

        return self

    def set_cutoff(self, cutoff: float, unit: str = None):
        """
        Set cutoff distances.

        Args:
            cutoff (float): Cutoff distances. None disables the cutoff.
            unit (str):     Length unit of `cutoff`, passed on to `self.units.length`. Default: None

        Returns:
            EnergyCell, this instance (to allow chaining).
        """
        if cutoff is None:
            self.cutoff = None
        else:
            cutoff = get_tensor(cutoff, dtype=torch.float32)
            self.cutoff = self.units.length(cutoff, unit)
        return self

    def set_pbc(self, use_pbc: bool):
        """
        Set whether to use periodic boundary condition.

        Args:
            use_pbc (bool): Whether to use periodic boundary condition.

        Returns:
            EnergyCell, this instance (to allow chaining).
        """
        self._use_pbc = use_pbc
        return self

    def convert_energy_from(self, unit: str) -> float:
        """
        Convert energy from a unit.

        Args:
            unit (str): Energy unit.

        Returns:
            float, conversion factor.
        """
        return self.units.convert_energy_from(unit)

    def convert_energy_to(self, unit: str) -> float:
        """
        Convert energy to a unit.

        Args:
            unit (str): Energy unit.

        Returns:
            float, conversion factor.
        """
        return self.units.convert_energy_to(unit)

    def forward(self,
                coordinate: Tensor,
                neighbour_index: Tensor = None,
                neighbour_mask: Tensor = None,
                neighbour_vector: Tensor = None,
                neighbour_distance: Tensor = None,
                pbc_box: Tensor = None
                ):
        r"""Calculate energy.

        Args:
            coordinate (Tensor):            Tensor of shape (B, A, D). Data type is float.
                                            Position coordinate of atoms in system.
            neighbour_index (Tensor):       Tensor of shape (B, A, N). Data type is int.
                                            Index of neighbour atoms. Default: None
            neighbour_mask (Tensor):        Tensor of shape (B, A, N). Data type is bool.
                                            Mask for neighbour atoms. Default: None
            neighbour_vector (Tensor):      Tensor of shape (B, A, N, D). Data type is float.
                                            Vectors from central atom to neighbouring atoms. Default: None
            neighbour_distance (Tensor):    Tensor of shape (B, A, N). Data type is float.
                                            Distance between neighbours atoms. Default: None
            pbc_box (Tensor):               Tensor of shape (B, D). Data type is float.
                                            Tensor of PBC box. Default: None

        Returns:
            energy (Tensor): Tensor of shape (B, 1). Data type is float.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.

        Symbols:
            B:  Batchsize, i.e. number of walkers in simulation
            A:  Number of atoms.
            N:  Maximum number of neighbour atoms.
            D:  Spatial dimension of the simulation system. Usually is 3.

        """
        raise NotImplementedError

    def print_info(self) -> dict:
        """Return basic information about the energy term.

        Despite its name this method does not print anything; it only builds the dict.

        Returns:
            dict: Dictionary containing energy information.
        """
        info = {
            "Energy type": self.name,
            "Energy unit": self.units.energy_unit,
            "Length unit": self.units.length_unit,
            # float() accepts both a one-element tensor (as stored by `set_cutoff`)
            # and a plain Python number assigned directly to `self.cutoff`.
            "Cutoff": float(self.cutoff) if self.cutoff is not None else 'No cutoff',
            "Use PBC": self.use_pbc if self.use_pbc is not None else 'Not specified'
        }
        return info


class NonbondEnergy(EnergyCell):
    r"""
    Base class for non-bonded energy terms.

    Args:
        name (str):         Name of energy. Default: 'nonbond'
        cutoff (Union[float, Length, Tensor]):
                            Cutoff distance. A `Length` is first evaluated with the units of this
                            energy term. Default: None
        length_unit (str):  Length unit. If None is given, it will be assigned with the global length unit.
                            Default: 'nm'
        energy_unit (str):  Energy unit. If None is given, it will be assigned with the global energy unit.
                            Default: 'kj/mol'
        use_pbc (bool):     Whether to use periodic boundary condition. Default: None
        **kwargs:           Extra keyword arguments, forwarded to `EnergyCell`.

    Returns:
        Tensor of energy, Tensor of shape `(B, 1)`. Data type is float.

    Supported Platforms:
        ``CPU`` ``GPU``

    Symbols:
        B:  Batchsize, i.e. number of walkers in simulation
    """

    def __init__(self,
                 name: str = 'nonbond',
                 cutoff: Union[float, Length, Tensor] = None,
                 length_unit: str = 'nm',
                 energy_unit: str = 'kj/mol',
                 use_pbc: bool = None,
                 **kwargs
                 ):

        super().__init__(
            name=name,
            length_unit=length_unit,
            energy_unit=energy_unit,
            use_pbc=use_pbc,
            **kwargs
        )

        # A `Length` object is called with this term's units (presumably yielding the
        # value expressed in those units) before being stored via `set_cutoff`.
        if isinstance(cutoff, Length):
            cutoff = cutoff(self.units)
        self.set_cutoff(cutoff)

    def forward(self,
                coordinate: Tensor,
                neighbour_index: Tensor = None,
                neighbour_mask: Tensor = None,
                neighbour_vector: Tensor = None,
                neighbour_distance: Tensor = None,
                pbc_box: Tensor = None
                ):
        r"""Calculate energy.

        Args:
            coordinate (Tensor):            Tensor of shape (B, A, D). Data type is float.
                                            Position coordinate of atoms in system.
            neighbour_index (Tensor):       Tensor of shape (B, A, N). Data type is int.
                                            Index of neighbour atoms. Default: None
            neighbour_mask (Tensor):        Tensor of shape (B, A, N). Data type is bool.
                                            Mask for neighbour atoms. Default: None
            neighbour_vector (Tensor):      Tensor of shape (B, A, N, D). Data type is float.
                                            Vectors from central atom to neighbouring atoms. Default: None
            neighbour_distance (Tensor):    Tensor of shape (B, A, N). Data type is float.
                                            Distance between neighbours atoms. Default: None
            pbc_box (Tensor):               Tensor of shape (B, D). Data type is float.
                                            Tensor of PBC box. Default: None

        Returns:
            energy (Tensor): Tensor of shape (B, 1). Data type is float.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.

        Symbols:
            B:  Batchsize, i.e. number of walkers in simulation
            A:  Number of atoms.
            N:  Maximum number of neighbour atoms.
            D:  Spatial dimension of the simulation system. Usually is 3.

        """
        raise NotImplementedError
